{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from os.path import join as pjoin\n",
    "\n",
    "from data_loader.data_utils import *\n",
    "from utils.math_graph import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# n: number of graph vertices; the weight-matrix file is named after it.\n",
    "# n_his / n_pred: history and prediction lengths; data_gen receives n_his + n_pred.\n",
    "n = 228\n",
    "n_his = 12\n",
    "n_pred = 9\n",
    "\n",
    "# Derive the file name from n so the loaded matrix always matches the configured size.\n",
    "W = weight_matrix(pjoin('./data_loader/data', f'W_{n}.csv'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(228, 228)\n"
     ]
    }
   ],
   "source": [
    "print(W.shape)\n",
    "# Fail fast if the loaded matrix does not match the configured vertex count.\n",
    "assert W.shape == (n, n), 'weight matrix must be n x n'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0. 0. 0. ... 0. 0. 0.]\n",
      " [0. 0. 0. ... 0. 0. 0.]\n",
      " [0. 0. 0. ... 0. 0. 0.]\n",
      " ...\n",
      " [0. 0. 0. ... 0. 0. 0.]\n",
      " [0. 0. 0. ... 0. 0. 0.]\n",
      " [0. 0. 0. ... 0. 0. 0.]]\n"
     ]
    }
   ],
   "source": [
    "# numpy elides the middle of large arrays; the visible zero corners\n",
    "# say nothing about the entries that are not shown.\n",
    "print(str(W))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Derive the file name from n so the data always matches the configured size.\n",
    "data_file = f'V_{n}.csv'\n",
    "# Split sizes for train / val / test, presumably in days -- TODO confirm against data_gen.\n",
    "n_train, n_val, n_test = 34, 5, 5\n",
    "PeMS = data_gen(pjoin('./data_loader/data', data_file), (n_train, n_val, n_test), n, n_his + n_pred)\n",
    "\n",
    "# Each split should have trailing shape (n_his + n_pred, n, 1).\n",
    "for split in ('train', 'test'):\n",
    "    assert PeMS.get_data(split).shape[1:] == (n_his + n_pred, n, 1), split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(9112, 21, 228, 1)\n"
     ]
    }
   ],
   "source": [
    "# Training split shape; axis 1 has n_his + n_pred entries, axis 2 has n.\n",
    "train_shape = PeMS.get_data('train').shape\n",
    "print(train_shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1340, 21, 228, 1)\n"
     ]
    }
   ],
   "source": [
    "# Test split shape; trailing dimensions match the training split.\n",
    "test_shape = PeMS.get_data('test').shape\n",
    "print(test_shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(21, 228, 1)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Keep a handle on the training array, then inspect its first entry.\n",
    "train = PeMS.get_data('train')\n",
    "first_sample = train[0]\n",
    "first_sample.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(228, 1)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First step of the first sample; its leading axis has n entries.\n",
    "first_step = train[0][0]\n",
    "first_step.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show only a few entries; displaying all n rows bloats the saved notebook.\n",
    "train[0][0][:5]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
